/**
 * Copyright 2021-2024, XGBoost Contributors
 */
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

#include <algorithm>    // for copy, transform, for_each, stable_sort, min, max, equal
#include <cstddef>      // for size_t
#include <cstdint>      // for int32_t
#include <limits>       // for numeric_limits
#include <memory>       // for shared_ptr
#include <numeric>      // for accumulate, iota
#include <type_traits>  // for enable_if_t, is_same_v
#include <utility>      // for move
#include <vector>       // for vector

#include "../../collective/allgather.h"
#include "../../common/categorical.h"  // for CatBitField
#include "../../common/hist_util.h"    // for GHistRow, HistogramCuts
#include "../../common/linalg_op.h"    // for cbegin, cend, begin
#include "../../common/random.h"       // for ColumnSampler
#include "../constraints.h"            // for FeatureInteractionConstraintHost
#include "../param.h"                  // for TrainParam
#include "../split_evaluator.h"        // for TreeEvaluator
#include "expand_entry.h"              // for MultiExpandEntry
#include "hist_cache.h"                // for BoundedHistCollection
#include "xgboost/base.h"              // for bst_node_t, bst_target_t, bst_feature_t
#include "xgboost/context.h"           // for Context
#include "xgboost/linalg.h"            // for Constants, Vector

namespace xgboost::tree {
/**
 * @brief Gather the expand entries from all the workers.
 *
 * Every local entry is saved to JSON, dumped in binary form, exchanged with
 * `collective::VectorAllgatherV`, and loaded back into an `ExpandEntry`. The order of the
 * result is the order returned by the allgather; callers in this file index it as
 * `worker * entries.size() + i`.
 *
 * @param ctx     Context forwarded to the collective call.
 * @param entries Local expand entries on this worker.
 * @return Global expand entries gathered from all workers.
 */
template <typename ExpandEntry>
std::enable_if_t<std::is_same_v<ExpandEntry, CPUExpandEntry> ||
                     std::is_same_v<ExpandEntry, MultiExpandEntry>,
                 std::vector<ExpandEntry>>
AllgatherColumnSplit(Context const *ctx, std::vector<ExpandEntry> const &entries) {
  // Serialize every local entry into its own binary JSON buffer.
  std::vector<std::vector<char>> serialized_entries;
  serialized_entries.reserve(entries.size());
  for (auto const &entry : entries) {
    Json jentry{Object{}};
    entry.Save(&jentry);

    std::vector<char> out;
    Json::Dump(jentry, &out, std::ios::binary);

    serialized_entries.emplace_back(std::move(out));
  }
  auto all_serialized = collective::VectorAllgatherV(ctx, serialized_entries);
  // The gathered result must at least contain this worker's own entries.
  CHECK_GE(all_serialized.size(), entries.size());

  // Deserialize everything that was gathered.
  std::vector<ExpandEntry> all_entries(all_serialized.size());
  std::transform(all_serialized.cbegin(), all_serialized.cend(), all_entries.begin(),
                 [](std::vector<char> const &e) {
                   ExpandEntry entry;
                   auto je = Json::Load(StringView{e.data(), e.size()}, std::ios::binary);
                   entry.Load(je);
                   return entry;
                 });

  return all_entries;
}

class HistEvaluator {
 private:
  struct NodeEntry {
    /*! \brief statistics for node entry */
    GradStats stats;
    /*! \brief loss of this node, without split */
    bst_float root_gain{0.0f};
  };

  Context const *ctx_;
  TrainParam const *param_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  TreeEvaluator tree_evaluator_;
  bool is_col_split_{false};
  FeatureInteractionConstraintHost interaction_constraints_;
  /*! \brief Per-node gradient statistics and gain without split, indexed by node id. */
  std::vector<NodeEntry> snode_;

  /**
   * @brief Whether the node has missing values for the feature that produced `e`.
   *
   * If the sum of statistics for non-missing values in the node is equal to the sum of
   * statistics for all values, there are no missing values; otherwise there are. The
   * comparison is exact, without floating-point tolerance.
   *
   * @param e     Gradient sum over the non-missing values of one feature.
   * @param snode Entry of the node being split.
   */
  static bool SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
    return !(e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess());
  }

  /**
   * @brief Whether both children satisfy the `min_child_weight` hessian constraint.
   */
  [[nodiscard]] bool IsValid(GradStats const &left, GradStats const &right) const {
    return left.GetHess() >= param_->min_child_weight &&
           right.GetHess() >= param_->min_child_weight;
  }

  /**
   * \brief Use learned direction with one-hot split. Other implementations (LGB) create a
   *        pseudo-category for missing value but here we just do a complete scan to avoid
   *        making specialized histogram bin.
   *
   * Each category is tried as the sole member of the right child twice: once with the
   * missing values on the left (grouped with the other categories), and once with the
   * missing values on the right (grouped with the chosen category).
   */
  void EnumerateOneHot(common::HistogramCuts const &cut, common::ConstGHistRow hist,
                       bst_feature_t fidx, bst_node_t nidx,
                       TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                       SplitEntry *p_best) const {
    const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
    const std::vector<bst_float> &cut_val = cut.Values();

    bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
    bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
    bst_bin_t n_bins = iend - ibegin;

    GradStats left_sum;
    GradStats right_sum;
    // best split so far
    SplitEntry best;
    best.is_cat = false;  // marker for whether it's updated or not.

    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
    auto feature_sum = GradStats{
        std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
    // Gradient of rows that have no value for this feature.
    GradStats missing;
    auto const &parent = snode_[nidx];
    missing.SetSubstract(parent.stats, feature_sum);

    for (bst_bin_t i = ibegin; i != iend; i += 1) {
      auto split_pt = cut_val[i];

      // missing on left (treat missing as other categories)
      right_sum = GradStats{hist[i]};
      left_sum.SetSubstract(parent.stats, right_sum);
      if (IsValid(left_sum, right_sum)) {
        auto missing_left_chg =
            static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                       GradStats{right_sum}) -
                               parent.root_gain);
        best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
      }

      // missing on right (treat missing as chosen category)
      right_sum.Add(missing);
      left_sum.SetSubstract(parent.stats, right_sum);
      if (IsValid(left_sum, right_sum)) {
        auto missing_right_chg =
            static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                       GradStats{right_sum}) -
                               parent.root_gain);
        best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
      }
    }

    if (best.is_cat) {
      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
      best.cat_bits.resize(n, 0);
      common::CatBitField cat_bits{best.cat_bits};
      cat_bits.Set(best.split_value);
    }

    p_best->Update(best);
  }

  /**
   * \brief Enumerate with partition-based splits.
   *
   * The implementation is different from LightGBM. Firstly, we don't have a
   * pseudo-category for missing values; instead, we make 2 complete scans over the
   * histogram. Secondly, both scan directions generate splits in the same
   * order. The following table depicts the scan process; square brackets mark the
   * partition in which the gradient of missing values resides:
   *
   *   | Forward  | Backward |
   *   |----------+----------|
   *   | [BCDE] A | E [ABCD] |
   *   | [CDE] AB | DE [ABC] |
   *   | [DE] ABC | CDE [AB] |
   *   | [E] ABCD | BCDE [A] |
   *
   * At most `max_cat_threshold` categories are considered for the scanned partition.
   *
   * @param sorted_idx Bin indices local to the feature, sorted by the caller.
   */
  template <int d_step>
  void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
                     common::ConstGHistRow hist, bst_feature_t fidx, bst_node_t nidx,
                     TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                     SplitEntry *p_best) const {
    static_assert(d_step == +1 || d_step == -1, "Invalid step.");

    auto const &cut_ptr = cut.Ptrs();
    auto const &cut_val = cut.Values();
    auto const &parent = snode_[nidx];

    bst_bin_t f_begin = cut_ptr[fidx];
    bst_bin_t f_end = cut_ptr[fidx + 1];
    bst_bin_t n_bins_feature{f_end - f_begin};
    auto n_bins = std::min(param_->max_cat_threshold, n_bins_feature);

    // statistics on both sides of split
    GradStats left_sum;
    GradStats right_sum;
    // best split so far
    SplitEntry best;

    auto f_hist = hist.subspan(f_begin, n_bins_feature);
    bst_bin_t it_begin, it_end;
    if (d_step > 0) {
      it_begin = f_begin;
      it_end = it_begin + n_bins - 1;
    } else {
      it_begin = f_end - 1;
      it_end = it_begin - n_bins + 1;
    }

    bst_bin_t best_thresh{-1};
    for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
      auto j = i - f_begin;  // index local to current feature
      if (d_step == 1) {
        right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
        left_sum.SetSubstract(parent.stats, right_sum);  // missing on left
      } else {
        left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
        right_sum.SetSubstract(parent.stats, left_sum);  // missing on right
      }
      if (IsValid(left_sum, right_sum)) {
        auto loss_chg = evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                GradStats{right_sum}) -
                        parent.root_gain;
        // We don't have a numeric split point, nan here is a dummy split.
        if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
                        left_sum, right_sum)) {
          best_thresh = i;
        }
      }
    }

    if (best_thresh != -1) {
      auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
      best.cat_bits = decltype(best.cat_bits)(n, 0);
      common::CatBitField cat_bits{best.cat_bits};
      // Number of leading entries of `sorted_idx` that form the chosen category set.
      bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
      CHECK_GT(partition, 0);
      std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](std::size_t c) {
        auto cat = cut_val[c + f_begin];
        cat_bits.Set(cat);
      });
    }

    p_best->Update(best);
  }

  // Enumerate/Scan the split values of specific feature
  // Returns the sum of gradients corresponding to the data points that contains
  // a non-missing value for the particular feature fid.
  template <int d_step>
  GradStats EnumerateSplit(common::HistogramCuts const &cut, common::ConstGHistRow hist,
                           bst_feature_t fidx, bst_node_t nidx,
                           TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                           SplitEntry *p_best) const {
    static_assert(d_step == +1 || d_step == -1, "Invalid step.");

    // aliases
    const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
    const std::vector<bst_float> &cut_val = cut.Values();
    auto const &parent = snode_[nidx];

    // statistics on both sides of split
    GradStats left_sum;
    GradStats right_sum;
    // best split so far
    SplitEntry best;

    // bin boundaries
    CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
    CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
    // imin: index (offset) of the minimum value for feature fid need this for backward
    //       enumeration
    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
    // ibegin, iend: smallest/largest cut points for feature fid use int to allow for
    // value -1
    bst_bin_t ibegin, iend;
    if (d_step > 0) {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
      iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
    } else {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
    }

    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
      // start working
      // try to find a split
      left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
      right_sum.SetSubstract(parent.stats, left_sum);
      if (IsValid(left_sum, right_sum)) {
        bst_float loss_chg;
        bst_float split_pt;
        if (d_step > 0) {
          // forward enumeration: split at right bound of each bin
          loss_chg =
              static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                         GradStats{right_sum}) -
                                 parent.root_gain);
          split_pt = cut_val[i];  // not used for partition based
          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
        } else {
          // backward enumeration: split at left bound of each bin
          loss_chg =
              static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
                                                         GradStats{left_sum}) -
                                 parent.root_gain);
          if (i == imin) {
            split_pt = cut.MinValues()[fidx];
          } else {
            split_pt = cut_val[i - 1];
          }
          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
        }
      }
    }

    p_best->Update(best);
    return left_sum;
  }

 public:
  /**
   * @brief Find the best split for every node in `p_entries`.
   *
   * Features returned by the column sampler for each node's depth are scanned in parallel
   * with one candidate per (node, thread); the candidates are then reduced into
   * `p_entries`. With column-wise data split, the per-worker results are all-gathered and
   * reduced as well.
   */
  void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut,
                      common::Span<FeatureType const> feature_types, const RegTree &tree,
                      std::vector<CPUExpandEntry> *p_entries) {
    auto n_threads = ctx_->Threads();
    auto &entries = *p_entries;
    // All nodes are on the same level, so we can store the shared ptr.
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
    for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      auto nidx = entries[nidx_in_set].nid;
      features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
    }
    CHECK(!features.empty());
    const size_t grain_size = std::max<size_t>(1, features.front()->Size() / n_threads);
    common::BlockedSpace2d space(
        entries.size(), [&](size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
        grain_size);

    // Thread-local candidates, laid out as [node][thread].
    std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
    for (size_t i = 0; i < entries.size(); ++i) {
      for (decltype(n_threads) j = 0; j < n_threads; ++j) {
        tloc_candidates[i * n_threads + j] = entries[i];
      }
    }
    auto evaluator = tree_evaluator_.GetEvaluator();
    auto const &cut_ptrs = cut.Ptrs();

    common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
      auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
      auto best = &entry->split;
      auto nidx = entry->nid;
      auto histogram = hist[nidx];
      auto features_set = features[nidx_in_set]->ConstHostSpan();
      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
        auto fidx = features_set[fidx_in_set];
        bool is_cat = common::IsCat(feature_types, fidx);
        if (!interaction_constraints_.Query(nidx, fidx)) {
          continue;
        }
        if (is_cat) {
          auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
          if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
            EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
          } else {
            std::vector<size_t> sorted_idx(n_bins);
            std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
            auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
            // Sort the histogram to get contiguous partitions.
            std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
              auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
                         evaluator.CalcWeightCat(*param_, feat_hist[r]);
              return ret;
            });
            EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
            EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
          }
        } else {
          auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
          if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
            EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
          }
        }
      }
    });

    // Reduce the thread-local candidates into the output entries.
    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      for (decltype(n_threads) tidx = 0; tidx < n_threads; ++tidx) {
        entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
      }
    }

    if (is_col_split_) {
      // With column-wise data split, we gather the best splits from all the workers and update the
      // expand entries accordingly.
      auto all_entries = AllgatherColumnSplit(ctx_, entries);
      auto n_workers = collective::GetWorldSize();
      // The indexing below assumes every worker contributed exactly one entry per node;
      // fail loudly instead of reading out of bounds if that does not hold.
      CHECK_EQ(all_entries.size(), static_cast<std::size_t>(n_workers) * entries.size());
      for (decltype(n_workers) worker = 0; worker < n_workers; ++worker) {
        for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
          entries[nidx_in_set].split.Update(
              all_entries[worker * entries.size() + nidx_in_set].split);
        }
      }
    }
  }

  // Add splits to tree, handles all statistic
  void ApplyTreeSplit(CPUExpandEntry const &candidate, RegTree *p_tree) {
    auto evaluator = tree_evaluator_.GetEvaluator();
    RegTree &tree = *p_tree;

    GradStats parent_sum = candidate.split.left_sum;
    parent_sum.Add(candidate.split.right_sum);
    auto base_weight = evaluator.CalcWeight(candidate.nid, *param_, GradStats{parent_sum});
    auto left_weight =
        evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.left_sum});
    auto right_weight =
        evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.right_sum});

    if (candidate.split.is_cat) {
      tree.ExpandCategorical(
          candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
          candidate.split.DefaultLeft(), base_weight, left_weight * param_->learning_rate,
          right_weight * param_->learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    } else {
      tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                      candidate.split.DefaultLeft(), base_weight,
                      left_weight * param_->learning_rate, right_weight * param_->learning_rate,
                      candidate.split.loss_chg, parent_sum.GetHess(),
                      candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    }

    // Set up child constraints
    auto left_child = tree[candidate.nid].LeftChild();
    auto right_child = tree[candidate.nid].RightChild();
    tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
                             tree[candidate.nid].SplitIndex(), left_weight, right_weight);
    evaluator = tree_evaluator_.GetEvaluator();

    snode_.resize(tree.GetNodes().size());
    snode_.at(left_child).stats = candidate.split.left_sum;
    snode_.at(left_child).root_gain =
        evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});
    snode_.at(right_child).stats = candidate.split.right_sum;
    snode_.at(right_child).root_gain =
        evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});

    interaction_constraints_.Split(candidate.nid, tree[candidate.nid].SplitIndex(), left_child,
                                   right_child);
  }

  [[nodiscard]] auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
  [[nodiscard]] auto const &Stats() const { return snode_; }

  /**
   * @brief Reset the node statistics to a single root node and return the root weight.
   */
  float InitRoot(GradStats const &root_sum) {
    snode_.resize(1);
    auto root_evaluator = tree_evaluator_.GetEvaluator();

    snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
    snode_[0].root_gain =
        root_evaluator.CalcGain(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
    auto weight = root_evaluator.CalcWeight(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
    return weight;
  }

  // The column sampler must be constructed by caller since we need to preserve the rng
  // for the entire training session.
  explicit HistEvaluator(Context const *ctx, TrainParam const *param, MetaInfo const &info,
                         std::shared_ptr<common::ColumnSampler> sampler)
      : ctx_{ctx},
        param_{param},
        column_sampler_{std::move(sampler)},
        tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
        is_col_split_{info.IsColumnSplit()} {
    interaction_constraints_.Configure(*param, info.num_col_);
    column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                          param_->colsample_bynode, param_->colsample_bylevel,
                          param_->colsample_bytree);
  }
};

class HistMultiEvaluator {
  std::vector<double> gain_;
  linalg::Matrix<GradientPairPrecise> stats_;
  TrainParam const *param_;
  FeatureInteractionConstraintHost interaction_constraints_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  Context const *ctx_;
  bool is_col_split_{false};

 private:
  static double MultiCalcSplitGain(TrainParam const &param,
                                   linalg::VectorView<GradientPairPrecise const> left_sum,
                                   linalg::VectorView<GradientPairPrecise const> right_sum,
                                   linalg::VectorView<float> left_weight,
                                   linalg::VectorView<float> right_weight) {
    CalcWeight(param, left_sum, left_weight);
    CalcWeight(param, right_sum, right_weight);

    auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
    auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
    return left_gain + right_gain;
  }

  template <bst_bin_t d_step>
  bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
                      common::Span<common::ConstGHistRow> hist,
                      linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
                      SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
    auto const &cut_ptr = cut.Ptrs();
    auto const &cut_val = cut.Values();
    auto const &min_val = cut.MinValues();

    auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
    auto left_sum = sum.Slice(0, linalg::All());
    auto right_sum = sum.Slice(1, linalg::All());

    bst_bin_t ibegin, iend;
    if (d_step > 0) {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
      iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
    } else {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
    }
    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);

    auto n_targets = hist.size();
    auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
    auto left_weight = weight.Slice(0, linalg::All());
    auto right_weight = weight.Slice(1, linalg::All());

    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
      for (bst_target_t t = 0; t < n_targets; ++t) {
        auto t_hist = hist[t];
        auto t_p = parent_sum(t);
        left_sum(t) += t_hist[i];
        right_sum(t) = t_p - left_sum(t);
      }

      if (d_step > 0) {
        auto split_pt = cut_val[i];
        auto loss_chg =
            MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
            parent_gain;
        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
      } else {
        float split_pt;
        if (i == imin) {
          split_pt = min_val[fidx];
        } else {
          split_pt = cut_val[i - 1];
        }
        auto loss_chg =
            MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
            parent_gain;
        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
      }
    }
    // return true if there's missing. Doesn't handle floating-point error well.
    if (d_step == +1) {
      return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
                         linalg::cbegin(parent_sum));
    }
    return false;
  }

 public:
  void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist,
                      common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
    auto &entries = *p_entries;
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());

    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      auto nidx = entries[nidx_in_set].nid;
      features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
    }
    CHECK(!features.empty());

    std::int32_t n_threads = ctx_->Threads();
    std::size_t const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
    common::BlockedSpace2d space(
        entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
        grain_size);

    std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
    for (std::size_t i = 0; i < entries.size(); ++i) {
      for (std::int32_t j = 0; j < n_threads; ++j) {
        tloc_candidates[i * n_threads + j] = entries[i];
      }
    }
    common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
      auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
      auto best = &entry->split;
      auto parent_sum = stats_.Slice(entry->nid, linalg::All());
      std::vector<common::ConstGHistRow> node_hist;
      for (auto t_hist : hist) {
        node_hist.emplace_back((*t_hist)[entry->nid]);
      }
      auto features_set = features[nidx_in_set]->ConstHostSpan();

      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
        auto fidx = features_set[fidx_in_set];
        if (!interaction_constraints_.Query(entry->nid, fidx)) {
          continue;
        }
        auto parent_gain = gain_[entry->nid];
        bool missing =
            this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
        if (missing) {
          this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
        }
      }
    });

    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      for (auto tidx = 0; tidx < n_threads; ++tidx) {
        entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
      }
    }

    if (is_col_split_) {
      // With column-wise data split, we gather the best splits from all the workers and update the
      // expand entries accordingly.
      auto all_entries = AllgatherColumnSplit(ctx_, entries);
      for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
        for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
          entries[nidx_in_set].split.Update(
              all_entries[worker * entries.size() + nidx_in_set].split);
        }
      }
    }
  }

  linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
    auto n_targets = root_sum.Size();
    stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets);
    gain_.resize(1);

    linalg::Vector<float> weight({n_targets}, ctx_->Device());
    CalcWeight(*param_, root_sum, weight.HostView());
    auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
    gain_.front() = root_gain;

    auto h_stats = stats_.HostView();
    std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));

    return weight;
  }

  void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
    auto n_targets = p_tree->NumTargets();
    auto parent_sum = stats_.Slice(candidate.nid, linalg::All());

    auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
    auto base_weight = weight.Slice(0, linalg::All());
    CalcWeight(*param_, parent_sum, base_weight);

    auto left_weight = weight.Slice(1, linalg::All());
    auto left_sum =
        linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
    CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);

    auto right_weight = weight.Slice(2, linalg::All());
    auto right_sum =
        linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
    CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);

    p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                       candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
    CHECK(p_tree->IsMultiTarget());
    auto left_child = p_tree->LeftChild(candidate.nid);
    CHECK_GT(left_child, candidate.nid);
    auto right_child = p_tree->RightChild(candidate.nid);
    CHECK_GT(right_child, candidate.nid);

    std::size_t n_nodes = p_tree->Size();
    gain_.resize(n_nodes);
    // Re-calculate weight without learning rate.
    CalcWeight(*param_, left_sum, left_weight);
    CalcWeight(*param_, right_sum, right_weight);
    gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
    gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);

    if (n_nodes >= stats_.Shape(0)) {
      stats_.Reshape(n_nodes * 2, stats_.Shape(1));
    }
    CHECK_EQ(stats_.Shape(1), n_targets);
    auto left_sum_stat = stats_.Slice(left_child, linalg::All());
    std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
              linalg::begin(left_sum_stat));
    auto right_sum_stat = stats_.Slice(right_child, linalg::All());
    std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
              linalg::begin(right_sum_stat));
  }

  explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
                              std::shared_ptr<common::ColumnSampler> sampler)
      : param_{param},
        column_sampler_{std::move(sampler)},
        ctx_{ctx},
        is_col_split_{info.IsColumnSplit()} {
    interaction_constraints_.Configure(*param, info.num_col_);
    column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                          param_->colsample_bynode, param_->colsample_bylevel,
                          param_->colsample_bytree);
  }
};

/**
 * \brief CPU implementation of update prediction cache, which calculates the leaf value
 *        for the last tree and accumulates it to prediction vector.
 *
 * For every non-deleted leaf, the leaf value is added to `out_preds` at each row index in
 * that leaf's row set.
 *
 * \param ctx         Context providing the number of threads.
 * \param p_last_tree The last tree being updated by tree updater. Must not be null.
 * \param partitioner Row partitioners; each must hold one row set per tree node.
 * \param out_preds   Prediction vector on the CPU, indexed by row.
 */
template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
                               std::vector<Partitioner> const &partitioner,
                               linalg::VectorView<float> out_preds) {
  // Validate before dereferencing, consistent with the matrix overload.
  CHECK(p_last_tree);
  auto const &tree = *p_last_tree;
  CHECK(out_preds.Device().IsCPU());
  std::size_t n_nodes = tree.GetNodes().size();
  for (auto const &part : partitioner) {
    CHECK_EQ(part.Size(), n_nodes);
    common::BlockedSpace2d space(
        part.Size(), [&](std::size_t node) { return part[node].Size(); }, 1024);
    common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
      if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
        auto const &rowset = part[nidx];
        auto leaf_value = tree[nidx].LeafValue();
        for (auto const *it = rowset.begin() + r.begin(); it < rowset.begin() + r.end(); ++it) {
          out_preds(*it) += leaf_value;
        }
      }
    });
  }
}

/**
 * @brief CPU implementation of update prediction cache for a tree that may be multi-target.
 *
 * A single-target tree is delegated to the vector overload using column 0 of `out_preds`.
 * For a multi-target tree, every row in a leaf's row set gets the leaf's value for each
 * target added to the matching column of `out_preds`.
 *
 * @param ctx         Context providing the number of threads.
 * @param p_last_tree The last tree being updated by tree updater. Must not be null.
 * @param partitioner Row partitioners; each must hold one row set per tree node.
 * @param out_preds   Non-empty CPU prediction matrix with one column per target.
 */
template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
                               std::vector<Partitioner> const &partitioner,
                               linalg::MatrixView<float> out_preds) {
  CHECK_GT(out_preds.Size(), 0U);
  CHECK(p_last_tree);

  RegTree const &last_tree = *p_last_tree;
  if (!last_tree.IsMultiTarget()) {
    // Single-target trees only contribute to the first output column.
    auto first_column = out_preds.Slice(linalg::All(), 0);
    UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, first_column);
    return;
  }

  auto const *mt_tree = last_tree.GetMultiTargetTree();
  auto const n_tree_nodes = mt_tree->Size();
  auto const n_targets = last_tree.NumTargets();
  CHECK_EQ(out_preds.Shape(1), n_targets);
  CHECK(out_preds.Device().IsCPU());

  for (auto const &rows_by_node : partitioner) {
    CHECK_EQ(rows_by_node.Size(), n_tree_nodes);
    common::BlockedSpace2d blocks(
        rows_by_node.Size(), [&](std::size_t node) { return rows_by_node[node].Size(); }, 1024);
    common::ParallelFor2d(blocks, ctx->Threads(), [&](bst_node_t nidx, common::Range1d range) {
      if (!last_tree.IsLeaf(nidx)) {
        return;  // only leaves carry values to accumulate
      }
      auto const &node_rows = rows_by_node[nidx];
      auto values = mt_tree->LeafValue(nidx);
      bst_idx_t const *first = node_rows.begin() + range.begin();
      bst_idx_t const *last = node_rows.begin() + range.end();
      for (bst_idx_t const *row = first; row < last; ++row) {
        for (std::size_t target = 0; target < n_targets; ++target) {
          out_preds(*row, target) += values(target);
        }
      }
    });
  }
}
}  // namespace xgboost::tree
#endif  // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
